import torch
from transformers import BertModel, BertTokenizer, BertConfig


class Config:
    """Central settings object: file locations, hyperparameters and BERT assets.

    Every path attribute is a plain string formed by appending a fixed
    relative suffix to ``root_path``. Constructing an instance also loads the
    pretrained BERT model, tokenizer and configuration from
    ``<root_path>/models/bert-base-chinese``, so instantiation reads from disk.
    """

    def __init__(self, root_path):
        """Populate all settings relative to ``root_path``.

        Args:
            root_path: Project root directory as a string. Every suffix
                appended to it starts with ``/``, so a trailing slash on
                ``root_path`` produces ``//`` in the resulting paths.
        """
        self.model_name = 'bert'
        self.root_path = root_path

        def under(suffix):
            # Plain string concatenation; a non-string root raises TypeError.
            return self.root_path + suffix

        # Category definitions and stop-word list.
        self.class_path = under('/data/clean_data/categories.json')
        self.stopwords_path = under('/data/origin_data/stopwords.txt')

        # Directory used for saved models.
        self.save_model = under('/models/')

        # Model parameters.
        # NOTE(review): num_cat_classes presumably matches the 10 categories of
        # the shopping dataset referenced below — confirm.
        self.num_cat_classes = 10
        self.embed_dim = 768
        self.max_len = 300
        # Use the GPU whenever torch reports that CUDA is available.
        if torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'
        self.batch_size = 16
        self.hidden_size = 2048

        # Training schedule.
        self.epochs = 5
        self.lr = 2e-5

        # Pretrained BERT assets, all read from the same local directory.
        self.bert_path = under('/models/bert-base-chinese')
        bert_dir = self.bert_path
        self.bert_model = BertModel.from_pretrained(bert_dir)  # pretrained BERT model
        self.tokenizer = BertTokenizer.from_pretrained(bert_dir)  # matching tokenizer
        self.bert_config = BertConfig.from_pretrained(bert_dir)  # model configuration

        # Base-model settings: model save directory and image locations.
        self.basic_model_path = under('/src/basic_model/model/')
        self.final_img = under('/src/basic_model/img_data/final.png')
        self.thinking_img = under('/src/basic_model/img_data/think.png')
        self.img_path = under('/src/dataEDA/img/')
        self.data_path = under('/src/bert/data/')

        # Data file after cleaning.
        self.clean_file = under('/data/clean_data/online_cleaned_shopping.csv')

        # Data file before cleaning.
        self.fast_file = under('/data/origin_data/online_shopping_10_cats.csv')
        # Training and test splits.
        self.train_path = under('/data/clean_data/trainClean.csv')
        self.test_path = under('/data/clean_data/testClean.csv')
